{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:17.694984Z",
     "start_time": "2025-05-21T11:59:14.571855Z"
    }
   },
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# NOTE(review): wildcard import hides where names like Config and AutoTokenizer\n",
    "# come from; prefer explicit imports once MyHelper's surface is stable.\n",
    "from MyHelper import *"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\51165\\.conda\\envs\\e12\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:17.814374Z",
     "start_time": "2025-05-21T11:59:17.811817Z"
    }
   },
   "cell_type": "code",
    "source": [
     "# Checkpoint id (from MyHelper's Config) and HF dataset id used throughout.\n",
     "model_name = Config.hfl_chinese_macbert_base\n",
     "dataset_name = \"clue/clue\""
    ],
   "id": "827eadb07685dfc1",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:28.182645Z",
     "start_time": "2025-05-21T11:59:17.819378Z"
    }
   },
   "cell_type": "code",
    "source": [
     "ds = load_dataset(dataset_name, \"c3\")\n",
     "# Drop the test split; only train/validation are used below.\n",
     "ds.pop(\"test\")\n",
     "ds, ds[\"train\"][0:10]"
    ],
   "id": "2904fbd2387300b5",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DatasetDict({\n",
       "     train: Dataset({\n",
       "         features: ['id', 'context', 'question', 'choice', 'answer'],\n",
       "         num_rows: 11869\n",
       "     })\n",
       "     validation: Dataset({\n",
       "         features: ['id', 'context', 'question', 'choice', 'answer'],\n",
       "         num_rows: 3816\n",
       "     })\n",
       " }),\n",
       " {'id': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n",
       "  'context': [['男：你今天晚上有时间吗?我们一起去看电影吧?', '女：你喜欢恐怖片和爱情片，但是我喜欢喜剧片，科幻片一般。所以……'],\n",
       "   ['男：足球比赛是明天上午八点开始吧?', '女：因为天气不好，比赛改到后天下午三点了。'],\n",
       "   ['女：今天下午的讨论会开得怎么样?', '男：我觉得发言的人太少了。'],\n",
       "   ['男：我记得你以前很爱吃巧克力，最近怎么不吃了，是在减肥吗?', '女：是啊，我希望自己能瘦一点儿。'],\n",
       "   ['女：过几天刘明就要从英国回来了。我还真有点儿想他了，记得那年他是刚过完中秋节走的。',\n",
       "    '男：可不是嘛!自从我去日本留学，就再也没见过他，算一算都五年了。',\n",
       "    '女：从2000年我们在学校第一次见面到现在已经快十年了。我还真想看看刘明变成什么样了!',\n",
       "    '男：你还别说，刘明肯定跟英国绅士一样，也许还能带回来一个英国女朋友呢。'],\n",
       "   ['男：好久不见了，最近忙什么呢?',\n",
       "    '女：最近我们单位要搞一个现代艺术展览，正忙着准备呢。',\n",
       "    '男：你们不是出版公司吗?为什么搞艺术展览?',\n",
       "    '女：对啊，这次展览是我们出版的一套艺术丛书的重要宣传活动。'],\n",
       "   ['男：会议结束后，你记得把空调和灯都关了。', '女：好的，我知道了，明天见。'],\n",
       "   ['男：你出国读书的事定了吗?', '女：思前想后，还拿不定主意呢。'],\n",
       "   ['男：这件衣服我要了，在哪儿交钱?', '女：前边右拐就有一个收银台，可以交现金，也可以刷卡。'],\n",
       "   ['男：小李啊，你是我见过的最爱干净的学生。',\n",
       "    '女：谢谢教授夸奖。不过，您是怎么看出来的?',\n",
       "    '男：不管我叫你做什么，你总是推得干干净净。',\n",
       "    '女：教授，我……']],\n",
       "  'question': ['女的最喜欢哪种电影?',\n",
       "   '根据对话，可以知道什么?',\n",
       "   '关于这次讨论会，我们可以知道什么?',\n",
       "   '女的为什么不吃巧克力了?',\n",
       "   '现在大概是哪一年?',\n",
       "   '女的的公司为什么要做现代艺术展览?',\n",
       "   '他们最可能是什么关系?',\n",
       "   '女的是什么意思?',\n",
       "   '他们最可能在什么地方?',\n",
       "   '教授认为小李怎么样?'],\n",
       "  'choice': [['恐怖片', '爱情片', '喜剧片', '科幻片'],\n",
       "   ['今天天气不好', '比赛时间变了', '校长忘了时间'],\n",
       "   ['会是昨天开的', '男的没有参加', '讨论得不热烈', '参加的人很少'],\n",
       "   ['刷牙了', '要减肥', '口渴了', '吃饱了'],\n",
       "   ['2005年', '2010年', '2008年', '2009年'],\n",
       "   ['传播文化', '宣传新书', '推广现代艺术', '体现企业文化'],\n",
       "   ['同事', '司机和客人', '医生和病人'],\n",
       "   ['不想出国', '出国太难', '还在犹豫', '不想决定'],\n",
       "   ['医院', '迪厅', '商场', '饭馆'],\n",
       "   ['卫生习惯非常好', '做事的能力不够', '找借口拒绝做事', '记不住该做的事']],\n",
       "  'answer': ['喜剧片',\n",
       "   '比赛时间变了',\n",
       "   '讨论得不热烈',\n",
       "   '要减肥',\n",
       "   '2010年',\n",
       "   '宣传新书',\n",
       "   '同事',\n",
       "   '还在犹豫',\n",
       "   '商场',\n",
       "   '找借口拒绝做事']})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:28.192947Z",
     "start_time": "2025-05-21T11:59:28.189943Z"
    }
   },
   "cell_type": "code",
    "source": [
     "# Optional: cache the dataset to disk so later runs skip the download.\n",
     "# ds.save_to_disk(\"out/0204/ds\")\n",
     "# ds2 = load_from_disk(\"out/0204/ds\")\n",
     "# ds2"
    ],
   "id": "77c02215f41158b7",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:28.544871Z",
     "start_time": "2025-05-21T11:59:28.202099Z"
    }
   },
   "cell_type": "code",
    "source": [
     "# Fast BERT tokenizer matching the checkpoint; repr below shows its special tokens.\n",
     "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
     "tokenizer"
    ],
   "id": "dd52f5726af94245",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertTokenizerFast(name_or_path='hfl/chinese-macbert-base', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={\n",
       "\t0: AddedToken(\"[PAD]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t100: AddedToken(\"[UNK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t101: AddedToken(\"[CLS]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t102: AddedToken(\"[SEP]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t103: AddedToken(\"[MASK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "}\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:28.557848Z",
     "start_time": "2025-05-21T11:59:28.553844Z"
    }
   },
   "cell_type": "code",
   "source": "ds[\"train\"][:10][\"answer\"] + ds[\"train\"][:10][\"id\"]",
   "id": "cec14530359379fa",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['喜剧片',\n",
       " '比赛时间变了',\n",
       " '讨论得不热烈',\n",
       " '要减肥',\n",
       " '2010年',\n",
       " '宣传新书',\n",
       " '同事',\n",
       " '还在犹豫',\n",
       " '商场',\n",
       " '找借口拒绝做事',\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 4,\n",
       " 5,\n",
       " 6,\n",
       " 7,\n",
       " 8,\n",
       " 9]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:28.594342Z",
     "start_time": "2025-05-21T11:59:28.569941Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def process_function(examples):\n",
    "    context_list, question_choice_list, labels_list, = [], [], [],\n",
    "    for context, question, choice, answer in zip(examples[\"context\"], examples[\"question\"], examples[\"choice\"], examples[\"answer\"]):\n",
    "        for i in range(4):\n",
    "            context_list.append(\"\\n\".join(context))\n",
    "            question_choice = question + \" \" + choice[i] if i < len(choice) else \"不知道\"\n",
    "            question_choice_list.append(question_choice)\n",
    "        labels_list.append(choice.index(answer))\n",
    "\n",
    "    tokenized = tokenizer(context_list, question_choice_list, truncation=\"only_first\", max_length=256, padding=\"max_length\", return_tensors=\"pt\")\n",
    "    for k, v in tokenized.items():\n",
    "        v = v.view(-1, 4, 256)\n",
    "        tokenized[k] = v\n",
    "    tokenized[\"labels\"] = labels_list\n",
    "    return tokenized\n",
    "\n",
    "ds2 = ds.map(process_function, batched=True, batch_size=32)\n",
    "ds2"
   ],
   "id": "808c1b61cf17db56",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'context', 'question', 'choice', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 11869\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'context', 'question', 'choice', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 3816\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:32.665096Z",
     "start_time": "2025-05-21T11:59:28.616274Z"
    }
   },
   "cell_type": "code",
   "source": "tokenizer.batch_decode(ds2[\"train\"][\"input_ids\"][1])\n",
   "id": "2e0802fc726002f7",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS] 男 ： 足 球 比 赛 是 明 天 上 午 八 点 开 始 吧? 女 ： 因 为 天 气 不 好 ， 比 赛 改 到 后 天 下 午 三 点 了 。 [SEP] 根 据 对 话 ， 可 以 知 道 什 么? 今 天 天 气 不 好 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',\n",
       " '[CLS] 男 ： 足 球 比 赛 是 明 天 上 午 八 点 开 始 吧? 女 ： 因 为 天 气 不 好 ， 比 赛 改 到 后 天 下 午 三 点 了 。 [SEP] 根 据 对 话 ， 可 以 知 道 什 么? 比 赛 时 间 变 了 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',\n",
       " '[CLS] 男 ： 足 球 比 赛 是 明 天 上 午 八 点 开 始 吧? 女 ： 因 为 天 气 不 好 ， 比 赛 改 到 后 天 下 午 三 点 了 。 [SEP] 根 据 对 话 ， 可 以 知 道 什 么? 校 长 忘 了 时 间 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',\n",
       " '[CLS] 男 ： 足 球 比 赛 是 明 天 上 午 八 点 开 始 吧? 女 ： 因 为 天 气 不 好 ， 比 赛 改 到 后 天 下 午 三 点 了 。 [SEP] 不 知 道 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T11:59:36.944347Z",
     "start_time": "2025-05-21T11:59:32.674573Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer\n",
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "model = AutoModelForMultipleChoice.from_pretrained(\"hfl/chinese-macbert-base\")\n",
    "accuracy = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metric(pred):\n",
    "    predictions, labels = pred\n",
    "    predictions = np.argmax(predictions, axis=-1)\n",
    "    return accuracy.compute(predictions=predictions, references=labels)\n",
    "\n",
    "args = TrainingArguments(\n",
    "    output_dir=\"./muliple_choice\",\n",
    "    per_device_train_batch_size=4,\n",
    "    per_device_eval_batch_size=32,\n",
    "    gradient_accumulation_steps=8,\n",
    "    logging_steps=10,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    max_steps=30,\n",
    "    load_best_model_at_end=True,\n",
    "    fp16=True\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    tokenizer=tokenizer,\n",
    "    train_dataset=ds2[\"train\"],\n",
    "    eval_dataset=ds2[\"validation\"],\n",
    "    compute_metrics=compute_metric\n",
    ")\n",
    "\n"
   ],
   "id": "10b7499980470b1b",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForMultipleChoice were not initialized from the model checkpoint at hfl/chinese-macbert-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "C:\\Users\\51165\\AppData\\Local\\Temp\\ipykernel_19940\\3170570131.py:26: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T12:01:19.248416Z",
     "start_time": "2025-05-21T11:59:36.952848Z"
    }
   },
   "cell_type": "code",
   "source": "trainer.train()",
   "id": "32eea7227f9769d2",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ],
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='30' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [30/30 01:41, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>1.393900</td>\n",
       "      <td>1.347069</td>\n",
       "      <td>0.353774</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>1.354200</td>\n",
       "      <td>1.317731</td>\n",
       "      <td>0.368449</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>1.331100</td>\n",
       "      <td>1.308844</td>\n",
       "      <td>0.367400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=30, training_loss=1.3597391764322917, metrics={'train_runtime': 102.171, 'train_samples_per_second': 9.396, 'train_steps_per_second': 0.294, 'total_flos': 505168690544640.0, 'train_loss': 1.3597391764322917, 'epoch': 0.08086253369272237})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 10
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T12:38:35.525164Z",
     "start_time": "2025-05-21T12:38:35.437776Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class MultipleChoicePipeline:\n",
    "\n",
    "    def __init__(self, model, tokenizer):\n",
    "        self.model = model\n",
    "        self.tokenizer = tokenizer\n",
    "        self.device = model.device\n",
    "\n",
    "    def preprocess(self, context, question, choices):\n",
    "        cs, qcs = [], []\n",
    "\n",
    "        for choice in choices:\n",
    "            cs.append(context)\n",
    "            qcs.append(question + \" \" + choice)\n",
    "\n",
    "        tokenized = tokenizer(cs, qcs, return_tensors=\"pt\")\n",
    "        return tokenized\n",
    "\n",
    "    def predict(self, inputs):\n",
    "        print(inputs.keys())\n",
    "        inputs = {k: v.unsqueeze(0).cuda() for k , v in inputs.items()}\n",
    "        return self.model(**inputs).logits\n",
    "\n",
    "    def postprocess(self, logits, choices):\n",
    "        prediction = logits.argmax(-1)\n",
    "        return choices[prediction]\n",
    "\n",
    "    def __call__(self, context, question, choices):\n",
    "        inputs = self.preprocess(context, question, choices)\n",
    "        logits = self.predict(inputs)\n",
    "        result = self.postprocess(logits, choices)\n",
    "        return result\n",
    "pipe = MultipleChoicePipeline(model, tokenizer)\n",
    "pipe(\"小明在北京上班\", \"小明在哪里上班？\", [\"北京\", \"上海\", \"河北\", \"海南\", \"河北\"])\n"
   ],
   "id": "57aaa8b0b0263d16",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'北京'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 22
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "d82ad95972768622"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
